import torch
import onnxruntime

# Sanity check: run the same random input through the TorchScript model and
# its ONNX export, then compare the output shapes and values.
# NOTE(review): these graph names are hard-coded and must match the names used
# when the ONNX model was exported -- confirm against the export script.
INPUT_NAME = "input:0"
OUTPUT_NAME = "output:0"

input_tensor = torch.randn([1, 3, 512, 512], requires_grad=False)

# map_location="cpu": a module saved from a GPU would otherwise be restored
# onto CUDA and then fail against the CPU input (or fail on a CPU-only host).
jit_model = torch.jit.load("v1_jit.pth", map_location="cpu")
jit_model.eval()  # ensure inference behaviour for dropout / batch-norm layers
with torch.no_grad():  # gradients are not needed at inference time
    output = jit_model(input_tensor)
    print("jit shape", output.shape)

# Explicit providers: since onnxruntime 1.9, builds that ship more than one
# execution provider (e.g. onnxruntime-gpu) raise ValueError when none is given.
# Running on CPU also keeps the comparison with the CPU TorchScript run fair.
onnx_model = onnxruntime.InferenceSession(
    "v1_onnx.onnx", providers=["CPUExecutionProvider"]
)
outputs = onnx_model.run([OUTPUT_NAME], {INPUT_NAME: input_tensor.numpy()})
print("onnx shape", outputs[0].shape)

# Matching shapes alone do not prove the export is correct; compare values too.
onnx_output = torch.from_numpy(outputs[0]).to(output.dtype)
if onnx_output.shape == output.shape:
    max_abs_diff = (output - onnx_output).abs().max().item()
    print("max abs diff", max_abs_diff)
    print(
        "allclose",
        torch.allclose(output, onnx_output, rtol=1e-3, atol=1e-5),
    )
else:
    print("shape mismatch between jit and onnx outputs")